
import oneflow
from kamal.core.metrics.stream_metrics import Metric

__all__ = ['Accuracy']


class Accuracy(Metric):
    """Streaming top-1 classification accuracy.

    Accumulates the number of correct predictions and the number of target
    elements over successive calls to :meth:`update`; :meth:`get_results`
    returns their ratio.
    """

    def __init__(self, attach_to=None):
        # ``attach_to`` is forwarded unchanged to ``Metric``; its meaning is
        # defined by the base class (presumably it selects which outputs /
        # targets ``_attach`` hands to this metric -- confirm in stream_metrics).
        super(Accuracy, self).__init__(attach_to=attach_to)
        self.reset()

    @oneflow.no_grad()
    def update(self, outputs, targets):
        """Accumulate correct/total counts for one batch.

        Args:
            outputs: model scores; the predicted class is the arg-max over
                dim 1 (presumably ``(N, C, ...)`` logits -- confirm with callers).
            targets: ground-truth class indices with the same number of
                elements as the predictions.
        """
        outputs, targets = self._attach(outputs, targets)
        preds = outputs.max(1)[1]
        # reshape (not view) so non-contiguous inputs, e.g. sliced targets,
        # are accepted as well.
        self._correct += (preds.reshape(-1) == targets.reshape(-1)).sum()
        self._cnt += targets.numel()

    def get_results(self):
        """Return the accumulated accuracy as a detached CPU tensor.

        If no target elements have been seen yet, a zero tensor is returned
        instead of failing on ``0.0 / 0.0``.
        """
        if not self._cnt:
            return oneflow.tensor(0.0)
        return (self._correct / self._cnt).detach().cpu()

    def reset(self):
        """Clear the accumulated counts."""
        # NOTE(review): starting from 0.0 makes ``_correct`` a floating-point
        # tensor after the first update; very large totals may lose integer
        # precision -- confirm whether that matters for callers.
        self._correct = 0.0
        self._cnt = 0.0
